{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b739c0e2",
   "metadata": {},
   "source": [
    "# Herdnet Overhead Image Detection Demo with PytorchWildlife"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1197e180",
   "metadata": {},
   "source": [
    "This tutorial guides you on how to use PyTorchWildlife for image detection. We will go through the process of setting up the environment, defining the detection model, as well as performing inference and saving the results in different ways.\n",
    "\n",
    "## Prerequisites\n",
    "Install PytorchWildlife running the following commands:\n",
    "```bash\n",
    "conda create -n pytorch_wildlife python=3.10 -y\n",
    "conda activate pytorch_wildlife\n",
    "pip install PytorchWildlife\n",
    "```\n",
    "Also, make sure you have a CUDA-capable GPU if you intend to run the model on a GPU. This notebook can also run on CPU.\n",
    "\n",
    "## Importing libraries\n",
    "First, we'll start by importing the necessary libraries and modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c44e7713",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from PytorchWildlife.models import detection as pw_detection\n",
    "from PytorchWildlife import utils as pw_utils\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6abd07b5",
   "metadata": {},
   "source": [
    "## Model Initialization\n",
    "We will initialize the HerdNet model for overhead image detection. This model is designed locate and count dense herds in overhead images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb25db43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setting the device to use for computations ('cuda' indicates GPU)\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# Model arguments and initialization\n",
    "detection_model = pw_detection.HerdNet(device=DEVICE)\n",
    "# If you want to use ennedi dataset weights, you can use the following line:\n",
    "# detection_model = pw_detection.HerdNet(device=DEVICE, version=\"ennedi\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e57dcca",
   "metadata": {},
   "source": [
    "## Single Image Detection\n",
    "We will first perform detection on a single image. Make sure to verify that you have the image in the specified path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d730b20",
   "metadata": {},
   "outputs": [],
   "source": [
    "img_path = os.path.join(\".\",\"demo_data\",\"herdnet_imgs\",\"S_11_05_16_DSC01556.JPG\")\n",
    "results = detection_model.single_image_detection(img=img_path)\n",
    "pw_utils.save_detection_images_dots(results, os.path.join(\".\",\"herdnet_demo_output\"), overwrite=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e23329c",
   "metadata": {},
   "source": [
    "## Batch Image Detection\n",
    "Next, we'll demonstrate how to process multiple images in batches. This is useful when you have a large number of images and want to process them efficiently.\n",
    " For HerdNet using batch size 1 is enough because each image is divided into more patches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "561eff0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_path = os.path.join(\".\",\"demo_data\",\"herdnet_imgs\")\n",
    "results = detection_model.batch_image_detection(folder_path, batch_size=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb41830b",
   "metadata": {},
   "source": [
    "## Output Results\n",
    "PytorchWildlife allows to output detection results in multiple formats. Here are the examples:\n",
    "\n",
    "### 1. Annotated Images:\n",
    "This will output the images with dots (instead of bounding boxes) drawn around the detected animals. The images will be saved in the specified output directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f63310ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "pw_utils.save_detection_images_dots(results, \"herdnet_demo_batch_output\", folder_path, overwrite=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49be4063",
   "metadata": {},
   "source": [
    "### 2. JSON Format:\n",
    "This will output the detection results in JSON format, but for HerdNet instead of bounding boxes, we will have dots. The JSON file will be saved in the specified output directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b627280b",
   "metadata": {},
   "outputs": [],
   "source": [
    "pw_utils.save_detection_json_as_dots(results, os.path.join(\".\",\"herdnet_demo_batch_output.json\"),\n",
    "                             categories=detection_model.CLASS_NAMES,\n",
    "                             exclude_category_ids=[], # Category IDs can be found in the definition of each model.\n",
    "                             exclude_file_path=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4ee1d7b",
   "metadata": {},
   "source": [
    "### Copyright (c) Microsoft Corporation. All rights reserved.\n",
    "### Licensed under the MIT License."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
